import numpy as np
import torch
import random
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import re
import torch.nn as nn
import copy
import warnings
import copy

a = torch.tensor([
    [0,  1,  2,  3],
    [1,  1,  2,  3],
    [2,  1,  2,  3],
    [3,  1,  2,  3],
    [0,  1,  2,  3],
])
print('a\n', a)
b = torch.arange(4*5).reshape(5, 4)
print('b\n', b)
c = b[range(len(a))]  # 说明选所有的行
print("c\n", c)
print('a[:, 0]\n', a[:, 0])
d = b[range(len(a)), a[:, 0]]  #
print("d\n", d)  # 输出结果只有len(a)个说明只在每一行选一个而不是每一行都要选a[:, 0]的值
